"""
classifier.py
"""
from numbers import Number

from numpy import where, isnan, nan, zeros

from zipline.lib.quantiles import quantiles
from zipline.pipeline.term import ComputableTerm
from zipline.utils.input_validation import expect_types
from zipline.utils.numpy_utils import int64_dtype

from ..filters import NullFilter, NumExprFilter
from ..mixins import (
    CustomTermMixin,
    LatestMixin,
    PositiveWindowLengthMixin,
    RestrictedDTypeMixin,
    SingleInputMixin,
)


class Classifier(RestrictedDTypeMixin, ComputableTerm):
    """
    A Pipeline expression computing a categorical output.

    Classifiers are most commonly useful for describing grouping keys for
    complex transformations on Factor outputs. For example, Factor.demean() and
    Factor.zscore() can be passed a Classifier in their ``groupby`` argument,
    indicating that means/standard deviations should be computed on assets for
    which the classifier produced the same label.
    """
    ALLOWED_DTYPES = (int64_dtype,)  # Used by RestrictedDTypeMixin

    def isnull(self):
        """
        A Filter producing True for values where this term has missing data.
        """
        return NullFilter(self)

    def notnull(self):
        """
        A Filter producing True for values where this term has complete data.
        """
        # Logical negation of isnull().
        return ~self.isnull()

    # We explicitly don't support classifier to classifier comparisons, since
    # the numbers likely don't mean the same thing. This may be relaxed in the
    # future, but for now we're starting conservatively.
    # NOTE(review): equality is spelled ``eq`` rather than ``__eq__``,
    # presumably because ``__eq__`` is reserved for term identity elsewhere
    # in the Term hierarchy -- confirm before renaming.
    @expect_types(other=Number)
    def eq(self, other):
        """
        Construct a Filter returning True for asset/date pairs where the output
        of ``self`` matches ``other``.

        Parameters
        ----------
        other : numbers.Number
            The label to compare against (type enforced by ``expect_types``).
            It is converted with ``int()`` before being embedded in the
            filter expression.

        Returns
        -------
        filter : Filter
            A filter built by ``NumExprFilter.create`` over ``self``.

        Raises
        ------
        ValueError
            If ``other`` equals ``self.missing_value``.
        """
        # We treat this as an error because missing_values have NaN semantics,
        # which means this would return an array of all False, which is almost
        # certainly not what the user wants.
        if other == self.missing_value:
            raise ValueError(
                "Comparison against self.missing_value ({value}) in"
                " {typename}.eq().\n"
                "Missing values have NaN semantics, so the "
                "requested comparison would always produce False.\n"
                "Use the isnull() method to check for missing values.".format(
                    value=other,
                    typename=(type(self).__name__),
                )
            )
        # NOTE(review): int() truncates non-integral values (e.g. 1.5 -> 1),
        # so eq(1.5) would match label 1; consider rejecting such input.
        return NumExprFilter.create(
            "x_0 == {other}".format(other=int(other)),
            binds=(self,),
        )

    @expect_types(other=Number)
    def __ne__(self, other):
        """
        Construct a Filter returning True for asset/date pairs where the output
        of ``self`` does not match ``other``.

        Asset/date pairs whose output is ``self.missing_value`` always produce
        False, so this is not simply the negation of ``eq``.

        Parameters
        ----------
        other : numbers.Number
            The label to compare against (type enforced by ``expect_types``).
            It is converted with ``int()`` before being embedded in the
            filter expression.

        Returns
        -------
        filter : Filter
            A filter built by ``NumExprFilter.create`` over ``self``.
        """
        # The second clause excludes missing values, which keeps
        # NaN-like semantics: a missing label is neither equal nor unequal.
        return NumExprFilter.create(
            "((x_0 != {other}) & (x_0 != {missing}))".format(
                other=int(other),
                missing=self.missing_value,
            ),
            binds=(self,),
        )


class Everything(Classifier):
    """
    Classifier that assigns one shared label to every asset.

    Asset/date pairs selected by ``mask`` get label 0. All other pairs get
    ``missing_value`` (-1).
    """
    dtype = int64_dtype
    window_length = 0
    inputs = ()
    missing_value = -1

    def _compute(self, arrays, dates, assets, mask):
        # Start from an all-zero label grid shaped like the mask. Then
        # substitute the missing-value sentinel wherever the mask is False.
        shared_label = zeros(mask.shape, dtype=int64_dtype)
        sentinel = self.missing_value
        return where(mask, shared_label, sentinel)


class Quantiles(SingleInputMixin, Classifier):
    """
    Classifier that buckets its input into ``bins`` quantile labels.

    Values outside ``mask`` are replaced with NaN before binning. Any NaN in
    the binned output is relabeled as ``missing_value`` (-1).
    """
    params = ('bins',)
    dtype = int64_dtype
    window_length = 0
    missing_value = -1

    def _compute(self, arrays, dates, assets, mask):
        values = arrays[0]
        n_bins = self.params['bins']

        # Hide masked-out cells from the quantile computation.
        masked_values = where(mask, values, nan)
        labels = quantiles(masked_values, n_bins)

        # NaN can come from the mask above or from quantiles() itself.
        # Either way it becomes the sentinel missing value.
        nan_cells = isnan(labels)
        labels[nan_cells] = self.missing_value
        return labels.astype(int64_dtype)

    def short_repr(self):
        return '%s(%d)' % (type(self).__name__, self.params['bins'])


class CustomClassifier(PositiveWindowLengthMixin, CustomTermMixin, Classifier):
    """
    Base class for Classifiers defined by users.

    This class adds no behavior of its own. Everything comes from
    ``PositiveWindowLengthMixin``, ``CustomTermMixin`` and ``Classifier``.

    See Also
    --------
    zipline.pipeline.CustomFactor
    zipline.pipeline.CustomFilter
    """


class Latest(LatestMixin, CustomClassifier):
    """
    Classifier whose output is the most recent value of its input.

    All behavior is inherited from ``LatestMixin`` and ``CustomClassifier``.

    See Also
    --------
    zipline.pipeline.data.dataset.BoundColumn.latest
    zipline.pipeline.factors.factor.Latest
    zipline.pipeline.filters.filter.Latest
    """
